{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f5c4abc0",
   "metadata": {},
   "source": [
    "# Extension for Scikit-learn KNN for MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "23512089",
   "metadata": {},
   "outputs": [],
   "source": [
    "from timeit import default_timer as timer\n",
    "from IPython.display import HTML\n",
    "from sklearn import metrics\n",
    "from sklearn.datasets import fetch_openml\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6e359f6",
   "metadata": {},
   "source": [
    "### Download the data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "27b99b44",
   "metadata": {},
   "outputs": [],
   "source": [
    "x, y = fetch_openml(name=\"mnist_784\", return_X_y=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6259f584",
   "metadata": {},
   "source": [
    "Split the data into train and test sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "96e14dd7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((56000, 784), (14000, 784), (56000,), (14000,))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=72)\n",
    "x_train.shape, x_test.shape, y_train.shape, y_test.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0341cac9",
   "metadata": {},
   "source": [
    "### Patch original Scikit-learn with Extension for Scikit-learn\n",
    "Extension for Scikit-learn (previously known as daal4py) contains drop-in replacement functionality for the stock Scikit-learn package. You can take advantage of the performance optimizations of Extension for Scikit-learn by adding just two lines of code before the usual Scikit-learn imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "244c5bc9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Extension for Scikit-learn* enabled (https://github.com/uxlfoundation/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "from sklearnex import patch_sklearn\n",
    "\n",
    "patch_sklearn()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bb14ac8",
   "metadata": {},
   "source": [
    "Extension for Scikit-learn patching affects performance of specific Scikit-learn functionality. Refer to the [list of supported algorithms and parameters](https://uxlfoundation.github.io/scikit-learn-intelex/latest/algorithms.html) for details. In cases when unsupported parameters are used, the package fallbacks into original Scikit-learn. If the patching does not cover your scenarios, [submit an issue on GitHub](https://github.com/uxlfoundation/scikit-learn-intelex/issues)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "693b4e26",
   "metadata": {},
   "source": [
    "Training and predict KNN algorithm with Extension for Scikit-learn for MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e9b8f06b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Extension for Scikit-learn time: 1.45 s'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "params = {\"n_neighbors\": 40, \"weights\": \"distance\", \"n_jobs\": -1}\n",
    "start = timer()\n",
    "knn = KNeighborsClassifier(**params).fit(x_train, y_train)\n",
    "predicted = knn.predict(x_test)\n",
    "time_opt = timer() - start\n",
    "f\"Extension for Scikit-learn time: {time_opt:.2f} s\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8ca549ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification report for Extension for Scikit-learn KNN:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.97      0.99      0.98      1365\n",
      "           1       0.93      0.99      0.96      1637\n",
      "           2       0.99      0.94      0.96      1401\n",
      "           3       0.96      0.95      0.96      1455\n",
      "           4       0.98      0.96      0.97      1380\n",
      "           5       0.95      0.95      0.95      1219\n",
      "           6       0.96      0.99      0.97      1317\n",
      "           7       0.94      0.95      0.95      1420\n",
      "           8       0.99      0.90      0.94      1379\n",
      "           9       0.92      0.94      0.93      1427\n",
      "\n",
      "    accuracy                           0.96     14000\n",
      "   macro avg       0.96      0.96      0.96     14000\n",
      "weighted avg       0.96      0.96      0.96     14000\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "report = metrics.classification_report(y_test, predicted)\n",
    "print(f\"Classification report for Extension for Scikit-learn KNN:\\n{report}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd8e7b0b",
   "metadata": {},
   "source": [
    "*The first column of the classification report above is the class labels.*  \n",
    "  \n",
    "### Train the same algorithm with original Scikit-learn\n",
    "In order to cancel optimizations, we use *unpatch_sklearn* and reimport the class KNeighborsClassifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5bb884d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearnex import unpatch_sklearn\n",
    "\n",
    "unpatch_sklearn()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cfa0dba",
   "metadata": {},
   "source": [
    "Training and predict KNN algorithm with original Scikit-learn library for MNSIT dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ae421d8e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Original Scikit-learn time: 36.15 s'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "\n",
    "start = timer()\n",
    "knn = KNeighborsClassifier(**params).fit(x_train, y_train)\n",
    "predicted = knn.predict(x_test)\n",
    "time_original = timer() - start\n",
    "f\"Original Scikit-learn time: {time_original:.2f} s\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "33da9fd1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification report for original Scikit-learn KNN:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.97      0.99      0.98      1365\n",
      "           1       0.93      0.99      0.96      1637\n",
      "           2       0.99      0.94      0.96      1401\n",
      "           3       0.96      0.95      0.96      1455\n",
      "           4       0.98      0.96      0.97      1380\n",
      "           5       0.95      0.95      0.95      1219\n",
      "           6       0.96      0.99      0.97      1317\n",
      "           7       0.94      0.95      0.95      1420\n",
      "           8       0.99      0.90      0.94      1379\n",
      "           9       0.92      0.94      0.93      1427\n",
      "\n",
      "    accuracy                           0.96     14000\n",
      "   macro avg       0.96      0.96      0.96     14000\n",
      "weighted avg       0.96      0.96      0.96     14000\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "report = metrics.classification_report(y_test, predicted)\n",
    "print(f\"Classification report for original Scikit-learn KNN:\\n{report}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ffd79e96",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<h2>With scikit-learn-intelex patching you can:</h2><ul><li>Use your Scikit-learn code for training and prediction with minimal changes (a couple of lines of code);</li><li>Get comparable model quality</li><li>Get a <strong>24.9x</strong> speedup.</li></ul>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "HTML(\n",
    "    f\"<h2>With scikit-learn-intelex patching you can:</h2>\"\n",
    "    f\"<ul>\"\n",
    "    f\"<li>Use your Scikit-learn code for training and prediction with minimal changes (a couple of lines of code);</li>\"\n",
    "    f\"<li>Get comparable model quality</li>\"\n",
    "    f\"<li>Get a <strong>{(time_original/time_opt):.1f}x</strong> speedup.</li>\"\n",
    "    f\"</ul>\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
